/* Copyright (C) 2004 Stefan Bellon <sbellon@sbellon.de>
 *
 * This file is part of RemotePrinterFS.
 *
 * RemotePrinterFS is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * RemotePrinterFS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <socklib.h>
#include <inetlib.h>
#include <netdb.h>

#include "config.h"
#include "errors.h"
#include "syslogf.h"
#include "utils.h"

/* Upper bounds for the host, user and file names that are written into
   LPD commands and the control file; used both as "%.*s" precisions and
   to size the stack buffers those lines are formatted into. */
#define HOST_NAME_MAX   31
#define USER_NAME_MAX   31
#define FILE_NAME_MAX  131
/* Local ports lpd_connect() tries to bind to, lowest first (presumably to
   satisfy RFC 1179, which asks LPD clients to use source ports 721-731). */
#define MIN_LOCAL_PORT 721
#define MAX_LOCAL_PORT 731
/* Name reported for the printed file in the control file's "N" line. */
#define FILENAME       "<Wimp$Scrap>"

/* Which of the two job files is transmitted first ("sendfirst" key). */
typedef enum {datafirst, controlfirst} order_t;

/* Per-job state. set_parameters() resets it to defaults and set_value()
   overrides it from the special field. queue, user and local_hostname are
   heap copies released by free_parameters(). */
static os_error           server_error;   /* filled in by lpd_server_error() */
static struct sockaddr_in remote_addr;    /* server address; port 515 by default */
static int                socket_type;    /* SOCK_STREAM or SOCK_DGRAM ("network") */
static char               *queue;         /* remote queue name */
static char               *user;          /* "P" line of the control file */
static char               *local_hostname; /* "H" line and dfA/cfA file names */
static order_t            order;          /* controlfirst by default */
static bool               buffer_on_disk; /* "buffer" key; see set_parameters() */
static char               format;         /* control-file format letter, '\0' = unset */
static int                filesize;       /* length given to lpd_open(); < 0 if unknown */
static int                sendsize;       /* size announced when filesize is not > 0 */
static int                counter   = 0;  /* job number 0..999 in dfA/cfA/spool names */

/* Apply one key/value setting from the printer's special field.
 *
 * Called by parse_special_field() for each key it finds. Keys are compared
 * case-insensitively:
 *
 *   network          "tcp" or "udp"
 *   address, host    host name, resolved with gethostbyname()
 *   service, port    decimal port number (1..65535) or service name
 *   queue            remote queue name
 *   user             user name for the control file
 *   local            local host name for the control file
 *   sendfirst        "controlfile" or "datafile"
 *   buffer           "yes" or "no"
 *   format           exactly one of the letters c d f g l n o p r t v
 *   sendsize         decimal size announced when the job length is not > 0
 *
 * Unknown keys are ignored. Returns NULL on success, or (after logging) an
 * error block when a value is invalid.
 *
 * String values replace, and free, whatever was stored before (normally
 * the strdup()ed defaults from set_parameters()), so nothing is leaked.
 */
static os_error const *
set_value(char *key, char *value)
{
    syslogf(LOG_DEBUG, "Found key '%s', value '%s'\n", key, value);

    if (!strcasecmp(key, "network")) {
        if (!strcasecmp(value, "tcp"))
            socket_type = SOCK_STREAM;
        else if (!strcasecmp(value, "udp"))
            socket_type = SOCK_DGRAM;
        else {
            syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
            return ERR(err_InvalidSpecialField);
        }

    } else if (!strcasecmp(key, "address") || !strcasecmp(key, "host")) {
        struct hostent *hent = gethostbyname(value);
        if (!hent) {
            syslogf(LOG_ERR, "Hostname lookup failed for: %s", value);
            return ERR(err_NameLookup);
        }
        /* sin_addr only has room for an IPv4 address; copying h_length
           bytes unchecked could overrun remote_addr. */
        if (hent->h_length != (int) sizeof(remote_addr.sin_addr)) {
            syslogf(LOG_ERR, "Unsupported address length %i for: %s",
                    hent->h_length, value);
            return ERR(err_NameLookup);
        }
        memcpy(&remote_addr.sin_addr, hent->h_addr_list[0],
               sizeof(remote_addr.sin_addr));

    } else if (!strcasecmp(key, "service") || !strcasecmp(key, "port")) {
        if (aredigit(value)) {
            /* Base 10 explicitly: base 0 would read "0515" as octal. */
            unsigned long port = strtoul(value, NULL, 10);
            if (port == 0 || port > 65535) {
                syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
                return ERR(err_InvalidSpecialField);
            }
            remote_addr.sin_port = htons((unsigned short) port);
        } else {
            /* Looked up for the protocol selected so far, so a "network"
               key only affects this if it was processed earlier
               (assuming parse_special_field() passes keys in order). */
            struct servent *se =
              getservbyname(value, socket_type == SOCK_STREAM ? "tcp" : "udp");
            if (!se) {
                syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
                return ERR(err_UnknownService);
            }
            remote_addr.sin_port = se->s_port;
        }

    } else if (!strcasecmp(key, "queue")) {
        free(queue);            /* drop the default or an earlier value */
        queue = strdup(value);

    } else if (!strcasecmp(key, "user")) {
        free(user);
        user = strdup(value);

    } else if (!strcasecmp(key, "local")) {
        free(local_hostname);
        local_hostname = strdup(value);

    } else if (!strcasecmp(key, "sendfirst")) {
        if (!strcasecmp(value, "controlfile"))
            order = controlfirst;
        else if (!strcasecmp(value, "datafile"))
            order = datafirst;
        else {
            syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
            return ERR(err_InvalidSpecialField);
        }

    } else if (!strcasecmp(key, "buffer")) {
        if (!strcasecmp(value, "yes"))
            buffer_on_disk = true;
        else if (!strcasecmp(value, "no"))
            buffer_on_disk = false;
        else {
            syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
            return ERR(err_InvalidSpecialField);
        }

    } else if (!strcasecmp(key, "format")) {
        /* Exactly one letter. Test for the empty string first: strchr()
           also matches the terminating NUL, and value[1] would then be
           read past the end of the string. */
        if (value[0] != '\0' && value[1] == '\0'
            && strchr("cdfglnoprtv", value[0]))
            format = value[0];
        else {
            syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
            return ERR(err_InvalidSpecialField);
        }

    } else if (!strcasecmp(key, "sendsize")) {
        if (!aredigit(value)) {
            syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
            return ERR(err_InvalidSpecialField);
        }
        sendsize = (int) strtol(value, NULL, 10);
        if (sendsize < 0) {
            syslogf(LOG_ERR, "Invalid value for key '%s': %s", key, value);
            return ERR(err_InvalidSpecialField);
        }
    }

    return NULL;
}

static char*
get_hostname(void)
{
    char hostname[HOST_NAME_MAX+1];
    
    gethostname(hostname, HOST_NAME_MAX+1);
    
    return strdup(hostname);
}

/* Release the heap-allocated string parameters (local host name, user and
 * queue) and reset each pointer to NULL, so calling this twice is harmless.
 * free(NULL) is a no-op, so no guards are needed.
 */
static void
free_parameters(void)
{
    free(local_hostname);
    local_hostname = NULL;

    free(user);
    user = NULL;

    free(queue);
    queue = NULL;
}

/* Map a RISC OS filetype to a control-file format letter.
 * Filetypes not listed below are logged and mapped to 'l'. */
static char
format_for_filetype(int filetype)
{
    switch (filetype) {
    case 0xce4:                 /* DVI */
        return 'd';
    case 0xfff:                 /* Text */
        return 'f';
    case 0xff5:                 /* PostScript */
        return 'o';
    case 0x000:
    case 0xff4:                 /* Printout */
        return 'l';
    default:
        syslogf(LOG_WARNING,
                "Found filetype &%X, please inform <sbellon@sbellon.de>!",
                filetype);
        return 'l';
    }
}

/* Reset every job parameter to its default, then apply the overrides from
 * the special field.
 *
 *   special_field  settings string handed to parse_special_field()
 *   length         job length; negative when it is not known in advance
 *   filetype       RISC OS filetype, used to choose the format letter when
 *                  the special field does not set one
 *
 * Returns (and stores in the global `error`) the result of parsing the
 * special field. The format guess and the disk-buffering rule are applied
 * even when parsing failed.
 */
static os_error const *
set_parameters(char *special_field, int length, int filetype)
{
    free_parameters();

    /* Defaults: TCP to port 515, queue "auto", user "root", this host. */
    memset(&remote_addr, 0, sizeof(remote_addr));
    remote_addr.sin_family      = AF_INET;
    remote_addr.sin_port        = htons(515);
    remote_addr.sin_addr.s_addr = 0x00000000;
    socket_type    = SOCK_STREAM;
    queue          = strdup("auto");
    user           = strdup("root");
    local_hostname = get_hostname();
    order          = controlfirst;
    buffer_on_disk = false;
    format         = '\0';
    filesize       = length;
    sendsize       = 0;

    error = parse_special_field(special_field, set_value);

    if (format == '\0')
        format = format_for_filetype(filetype);

    /* Sending the data file first needs its size up front; with an unknown
       length that is only possible by buffering the job on disk. */
    if (order == datafirst && filesize < 0)
        buffer_on_disk = true;

    return error;
}

/* Create a socket of type socket_type, bind it to the first local port in
 * MIN_LOCAL_PORT..MAX_LOCAL_PORT for which bind() succeeds, and connect it
 * to remote_addr.
 *
 * Returns the connected socket. On failure it sets and logs the global
 * `error`, closes any socket it created and returns -1.
 */
static int
lpd_connect(void)
{
    struct sockaddr_in local_addr;
    int sock;
    int port;

    sock = socket(AF_INET, socket_type, 0);
    if (sock == -1) {
        error = ERR(err_Socket);
        syslogf(LOG_ERR, "%s", error->errmess);
        return -1;
    }

    memset(&local_addr, 0, sizeof(local_addr));
    local_addr.sin_family      = AF_INET;
    local_addr.sin_addr.s_addr = 0x00000000;

    port = MIN_LOCAL_PORT;
    while (port <= MAX_LOCAL_PORT) {
        local_addr.sin_port = htons(port);
        if (bind(sock, (struct sockaddr *) &local_addr, sizeof(local_addr)) != -1)
            break;
        ++port;
    }
    if (port > MAX_LOCAL_PORT) {
        /* Every port in the range was refused. */
        close(sock);
        error = ERR(err_LocalAddress);
        syslogf(LOG_ERR, "%s", error->errmess);
        return -1;
    }

    if (connect(sock, (struct sockaddr*) &remote_addr, sizeof(remote_addr)) == -1) {
        close(sock);
        error = ERR(err_RemoteAddress);
        syslogf(LOG_ERR, "%s", error->errmess);
        return -1;
    }

    return sock;
}

/* Read the one-byte answer the server sends after a command or a
 * transferred file.
 *
 * Returns NULL when the byte is zero. Otherwise it fills in, logs and
 * returns the static server_error block. That happens when the read did
 * not yield exactly one byte, or when the server sent a non-zero code.
 * The block is overwritten by the next failing call.
 */
static os_error const *
lpd_server_error(int s)
{
    char buf;
    int num;

    syslogf(LOG_DEBUG|LOG_INFO, "Waiting for server answer ...");
    num = socketread(s, &buf, 1);
    if (num != 1) {
        server_error.errnum = ERR(err_ServerError)->errnum;
        snprintf(server_error.errmess, sizeof(server_error.errmess),
                 "Server error signaled (size of response: %i)", num);
        syslogf(LOG_ERR, "%s", server_error.errmess);
        return &server_error;
    }

    if (buf == 0) {
        syslogf(LOG_DEBUG, "Server answered with OK\n");
        return NULL;
    }

    /* Convert via unsigned char so that, where plain char is signed, a
       code >= 0x80 prints as e.g. "80" rather than "ffffff80". */
    server_error.errnum = ERR(err_ServerError)->errnum;
    snprintf(server_error.errmess, sizeof(server_error.errmess),
             "Server error signaled (code=%x)", (unsigned) (unsigned char) buf);
    syslogf(LOG_ERR, "%s", server_error.errmess);

    return &server_error;
}

/* Announce and then send the control file for the current job, waiting
 * for the server's answer after each step.
 *
 * The control file contains a "P" line (user) and an "H" line (local
 * host). It also has a line made of the format letter followed by the data
 * file name "dfA<counter><host>", and an "N" line (FILENAME).
 *
 * Returns NULL on success or the error from lpd_server_error().
 */
static os_error const *
lpd_send_controlfile(int s)
{
    /* Large enough for the longest contents the "%.*s" limits allow. */
    char cf[FILE_NAME_MAX+2*HOST_NAME_MAX+USER_NAME_MAX+15];
    char cflen[HOST_NAME_MAX+20];
    unsigned cf_size;

    /* Construct control file */
    snprintf(cf, sizeof(cf),
             "P%.*s\012"
             "H%.*s\012"
             "%cdfA%03i%.*s\012"
             "N%.*s\012",
             USER_NAME_MAX, user,
             HOST_NAME_MAX, local_hostname,
             format, counter, HOST_NAME_MAX, local_hostname,
             FILE_NAME_MAX, FILENAME);
    /* strlen() yields size_t; %u needs unsigned int, so convert once and
       use the same value for the announcement and the transfer. */
    cf_size = (unsigned) strlen(cf);

    /* Send length of control file */
    syslogf(LOG_DEBUG, "Sending control file length\n");
    snprintf(cflen, sizeof(cflen), "\002%u cfA%03i%.*s\012",
             cf_size, counter, HOST_NAME_MAX, local_hostname);
    socketwrite(s, cflen, strlen(cflen));
    error = lpd_server_error(s);
    if (error) {
        syslogf(LOG_ERR, "%s", error->errmess);
        return error;
    }

    /* Send the control file, terminated by a zero byte */
    syslogf(LOG_DEBUG, "Sending control file\n");
    socketwrite(s, cf, cf_size);
    socketwrite(s, "\000", 1);
    error = lpd_server_error(s);
    if (error) {
        syslogf(LOG_ERR, "%s", error->errmess);
        return error;
    }

    return NULL;
}

static os_error const *
lpd_start_job(int s)
{
    char buf[2*FILE_NAME_MAX+HOST_NAME_MAX];
    
    /* Select queue */
    syslogf(LOG_DEBUG, "Selecting queue\n");
    sprintf(buf, "\002%.*s\012", FILE_NAME_MAX, queue);
    socketwrite(s, buf, strlen(buf));
    error = lpd_server_error(s);
    if (error) {
        syslogf(LOG_ERR, "%s", error->errmess);
        return error;
    }

    if (order == controlfirst) {
        /* Send control file */
        error = lpd_send_controlfile(s);
        if (error) {
            syslogf(LOG_ERR, "%s", error->errmess);
            return error;
        }
    }
    
    /* Send data */
    syslogf(LOG_DEBUG, "Sending data command\n");
    sprintf(buf, "\003%u dfA%03i%.*s\012",
            filesize > 0 ? (unsigned) filesize : (unsigned) sendsize,
            counter,
            HOST_NAME_MAX, local_hostname);
    socketwrite(s, buf, strlen(buf));
    error = lpd_server_error(s);
    if (error) {
        syslogf(LOG_ERR, "%s", error->errmess);
        return error;
    }

    syslogf(LOG_DEBUG, "Server prepared and ready to receive data\n");

    return NULL;
}

/* Send the abort-job line "\001\n" (RFC 1179 receive-job subcommand 01)
 * and close the connection. The result of the write is ignored.
 */
static void
lpd_abort_job_and_close(int s)
{
    static char abort_line[] = "\001\012";

    socketwrite(s, abort_line, sizeof(abort_line) - 1);
    close(s);
}

int
lpd_open(char *special_field, int length, int filetype)
{
    int s;

    syslogf(LOG_INFO, "Starting print job (%s) ...", special_field);
    syslogf(LOG_DEBUG, "File length to print is %i bytes.", length);

    error = set_parameters(special_field, length, filetype);
    if (error)
        return -1;

    /* We don't care whether we have to abort later on, just increase
       the job counter. If we abort, then not every job number is used. */
    if (++counter > 999)
        counter = 0;

    s = lpd_connect();
    if (s >= 0) {
        error = lpd_start_job(s);
        if (error) {
            syslogf(LOG_ERR, "%s", error->errmess);
            lpd_abort_job_and_close(s);
            return -1;
        }
    }
    
    return s;
}

/* Forward len bytes of print data from buf to the server connection s.
 * Returns socketwrite()'s result unchanged.
 */
int
lpd_write(int s, char *buf, int len)
{
    int written = socketwrite(s, buf, len);

    return written;
}

/* Finish the job on socket s and close the connection.
 *
 * controlfirst: the control file was sent by lpd_open(). With a known
 *   length (filesize >= 0) the data is terminated with a zero byte and
 *   the server's answer checked. Otherwise the connection is just closed.
 * datafirst: the data is terminated and acknowledged, then the control
 *   file is sent. This needs a known length.
 *
 * The socket is closed on every path. Returns close()'s result on
 * success, or -1 on failure (global `error` set and logged).
 */
int
lpd_close(int s)
{
    if (order == controlfirst) {
        if (filesize >= 0) {
            socketwrite (s, "\000", 1);
            /* Check the server accepted the data file. */
            error = lpd_server_error(s);
            if (error) {
                close(s);
                syslogf(LOG_ERR, "%s", error->errmess);
                return -1;
            }
        } else {
            /* No size was announced, so we rely on the server treating
               end-of-connection as end-of-data. Do not send the "\000"
               terminator, as some printers then output an additional
               sheet of paper. */
        }
    } else {
        if (filesize < 0) {
            /* Cannot work without a known length. set_parameters() turns
               on disk buffering in this case, so this should not be
               reached. Close the socket anyway rather than leaking it. */
            error = ERR(err_InternalError);
            syslogf(LOG_ERR, "%s", error->errmess);
            close(s);
            return -1;
        }

        /* Terminate the data file and read the server's answer to it. */
        socketwrite (s, "\000", 1);
        error = lpd_server_error(s);
        if (error) {
            syslogf(LOG_ERR, "%s", error->errmess);
            lpd_abort_job_and_close(s);
            return -1;
        }
        /* The control file has to be sent now. */
        error = lpd_send_controlfile(s);
        if (error) {
            syslogf(LOG_ERR, "%s", error->errmess);
            lpd_abort_job_and_close(s);
            return -1;
        }
    }

    syslogf(LOG_INFO, "Print job finished.");

    free_parameters();

    return close(s);
}

/* Work out whether the job described by special_field should be buffered
 * on disk. If so, return a newly malloc()ed spool file name "LPDnnn"
 * built from the current job counter; the caller must free() it.
 *
 * Returns NULL when no disk buffering is wanted, or when allocation fails.
 * The result of set_parameters() is not acted on here; invalid fields are
 * already logged by set_value().
 *
 * NOTE(review): lpd_open() increments the counter before starting its
 * job, so nnn need not equal the job number used later.
 */
char *
lpd_get_spool_filename(char *special_field, int length)
{
    /* "LPD" + three counter digits + NUL; counter stays within 0..999. */
    size_t const size = sizeof("LPD000");
    char *filename;

    syslogf(LOG_DEBUG, "Spool filename requested (%s) ...", special_field);
    syslogf(LOG_DEBUG, "File length to print is %i bytes.", length);

    (void) set_parameters(special_field, length, 0);

    if (!buffer_on_disk)
        return NULL;

    filename = malloc(size);
    if (!filename) {
        /* Previously execution fell through to sprintf(NULL, ...). */
        syslogf(LOG_CRIT, "malloc failed, no memory allocated");
        return NULL;
    }
    snprintf(filename, size, "LPD%03i", counter);
    syslogf(LOG_DEBUG, "Spool filename is \"%s\".", filename);

    return filename;
}
